#!/usr/bin/env python3
"""
Comprehensive examples for loading and plotting Lake Kinneret multi-sensor data

This script demonstrates how to load and visualize ALL datasets in the
Lake Kinneret 2017-2023 repository, including:
  1. Fluoroprobe in-situ measurements
  2. Laboratory chlorophyll measurements
  3. Satellite chlorophyll retrievals
  4. Meteorological observations
  5. Machine learning phytoplankton predictions
  6. Validation matchups and statistics

Requirements:
    - pandas
    - matplotlib
    - numpy
    - glob (standard library)

Usage:
    python load_and_plot_examples.py
"""

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from glob import glob

# Root of the dataset repository. This is a relative path, so it is resolved
# against the current working directory (adjust path as needed).
DATA_DIR = Path('../')  # Assumes script is run from examples/ directory
# Destination for generated figures; also relative to the working directory.
OUTPUT_DIR = Path('output')  # Output directory for generated plots

# Create output directory if it doesn't exist
OUTPUT_DIR.mkdir(exist_ok=True)

# Script banner
print("=" * 80)
print("Lake Kinneret Multi-Sensor Dataset (2017-2023)")
print("Comprehensive Loading and Plotting Examples")
print("=" * 80)

# =============================================================================
# 1. FLUOROPROBE IN-SITU MEASUREMENTS
# =============================================================================
# Section banner printed to stdout before the fluoroprobe examples run.
print("\n" + "=" * 80)
print("1. FLUOROPROBE IN-SITU MEASUREMENTS")
print("=" * 80)

def load_fluoroprobe():
    """Read the cleaned fluoroprobe CSV and print a short summary of it.

    Returns:
        pandas.DataFrame: the file contents with ``date`` parsed as datetimes.
    """
    print("\nLoading fluoroprobe data...")
    csv_path = DATA_DIR.joinpath('fluoroprobe', 'fluoroprobe_2017_2023_clean.csv')
    table = pd.read_csv(csv_path, parse_dates=['date'])

    n_rows = len(table)
    print(f"Loaded {n_rows:,} observations")

    station_ids = sorted(table['station'].unique())
    print(f"Stations: {station_ids}")

    when = table['date']
    print(f"Date range: {when.min()} to {when.max()}")

    depths = table['depth']
    print(f"Depth range: {depths.min():.1f} - {depths.max():.1f} m")
    return table

# Load the fluoroprobe table once at module level; the plotting call further
# down the script reuses this DataFrame.
fluoroprobe = load_fluoroprobe()

def plot_fluoroprobe_depth_profile(df, station='A', date=None):
    """Plot vertical fluorescence and temperature profiles for one station and day.

    Args:
        df: Fluoroprobe table with 'station', 'date' (datetime), 'depth',
            'led 3 525 nm', 'led 7  590 nm' and 'temp sample' columns.
        station: Station identifier to plot.
        date: Day to plot; anything ``pd.to_datetime`` accepts (date string,
            ``datetime.date``, ``datetime`` or ``pd.Timestamp``). When None,
            the chronologically middle day among the days with more than 30
            records is used, falling back to all days of the station.

    Returns:
        None. Writes '1_fluoroprobe_depth_profile.png' into OUTPUT_DIR, or
        prints a warning and returns early when there is nothing to plot.
    """
    print("\n  Plotting fluoroprobe depth profile...")

    station_rows = df[df['station'] == station]
    if station_rows.empty:
        # Without this guard the automatic date selection below would index
        # into an empty list and raise IndexError.
        print(f"    Warning: No data for Station {station}")
        return
    row_days = station_rows['date'].dt.date

    if date is None:
        # Pick a representative day: prefer days with >30 depth measurements.
        date_counts = station_rows.groupby(row_days).size()
        good_dates = date_counts[date_counts > 30].index
        if len(good_dates) == 0:
            good_dates = date_counts.index
        if len(good_dates) == 0:
            # Every date of this station is missing (NaT).
            print(f"    Warning: No dated records for Station {station}")
            return
        date = sorted(good_dates)[len(good_dates) // 2]  # Middle date
    else:
        # Normalise to datetime.date so strings and Timestamps compare
        # correctly against the per-row dates.
        date = pd.to_datetime(date).date()

    profile = station_rows[row_days == date]
    if len(profile) == 0:
        print(f"    Warning: No data for Station {station} on {date}")
        return

    # Sort by depth and average duplicate depths
    profile = (
        profile.sort_values('depth')
        .groupby('depth')
        .agg({
            'led 3 525 nm': 'mean',
            'led 7  590 nm': 'mean',
            'temp sample': 'mean',
        })
        .reset_index()
    )

    # (column, line format, marker face, marker edge, x label, title)
    panels = [
        ('led 3 525 nm', 'go-', 'lightgreen', 'darkgreen',
         'Fluorescence (RFU)', '525 nm'),
        ('led 7  590 nm', 'bo-', 'lightblue', 'darkblue',
         'Fluorescence (RFU)', '590 nm'),
        ('temp sample', 'ro-', 'lightcoral', 'darkred',
         'Water Temperature (C)', 'Temperature'),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 8), sharey=True)
    for ax, (column, fmt, face, edge, xlabel, title) in zip(axes, panels):
        ax.plot(profile[column], profile['depth'], fmt,
                linewidth=2.5, markersize=6, markerfacecolor=face,
                markeredgecolor=edge, markeredgewidth=1.5)
        ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
        ax.grid(True, alpha=0.4, linestyle='--')
        ax.set_title(title, fontweight='bold', fontsize=13)
        ax.tick_params(labelsize=11)

    # The y axis is shared, so inverting it once puts the surface at the top
    # of every panel.
    axes[0].invert_yaxis()
    axes[0].set_ylabel('Depth (m)', fontsize=12, fontweight='bold')

    fig.suptitle(f'Lake Kinneret Station {station} - Vertical Profile\n{date}',
                 fontsize=16, fontweight='bold', y=0.98)
    fig.tight_layout()
    out_path = OUTPUT_DIR / '1_fluoroprobe_depth_profile.png'
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    print(f"    Saved: {out_path.as_posix()}")
    plt.close(fig)

def plot_fluoroprobe_timeseries(df, station='A'):
    """Plot daily-mean surface fluorescence and temperature time series.

    The two fluorescence channels are labelled by wavelength only; raw
    wavelength signals should not be presented as algal groups.

    Args:
        df: Fluoroprobe table with 'station', 'date' (datetime), 'depth',
            'led 3 525 nm', 'led 7  590 nm' and 'temp sample' columns.
        station: Station identifier to plot.

    Returns:
        None. Writes '1_fluoroprobe_timeseries.png' into OUTPUT_DIR, or prints
        a warning and returns early when the station has no surface records.
    """
    print("\n  Plotting fluoroprobe time series...")

    # Filter for surface waters (<2m depth)
    surface = df[(df['station'] == station) & (df['depth'] < 2)].copy()
    if surface.empty:
        print(f"    Warning: No surface data for Station {station}")
        return

    # Aggregate by date (mean of surface measurements)
    daily = surface.groupby(surface['date'].dt.date).agg({
        'led 3 525 nm': 'mean',
        'led 7  590 nm': 'mean',
        'temp sample': 'mean'
    }).reset_index()
    daily['date'] = pd.to_datetime(daily['date'])
    daily = daily.sort_values('date')

    fig, axes = plt.subplots(2, 1, figsize=(16, 9), sharex=True)

    # Plot 1: fluorescence channels, identified by excitation wavelength
    axes[0].plot(daily['date'], daily['led 3 525 nm'], 'g-', linewidth=2,
                 label='525 nm', alpha=0.8)
    axes[0].plot(daily['date'], daily['led 7  590 nm'], 'b-', linewidth=2,
                 label='590 nm', alpha=0.8)
    axes[0].set_ylabel('Fluorescence (RFU)', fontsize=13, fontweight='bold')
    axes[0].legend(loc='upper right', fontsize=12, framealpha=0.95, edgecolor='black')
    axes[0].grid(True, alpha=0.3)
    axes[0].set_title(f'Lake Kinneret Station {station} - Surface Fluorescence (2017-2023)',
                      fontweight='bold', fontsize=14)
    axes[0].tick_params(labelsize=11)
    axes[0].set_ylim(bottom=0)

    # Plot 2: Temperature (line plus shaded area down to zero)
    axes[1].plot(daily['date'], daily['temp sample'], 'r-', linewidth=2, alpha=0.8)
    axes[1].fill_between(daily['date'], daily['temp sample'],
                         alpha=0.3, color='red', label='Temperature')
    axes[1].set_ylabel('Temperature (C)', fontsize=13, fontweight='bold')
    axes[1].set_xlabel('Date', fontsize=13, fontweight='bold')
    axes[1].grid(True, alpha=0.3)
    axes[1].set_title('Surface Water Temperature', fontweight='bold', fontsize=14)
    axes[1].tick_params(labelsize=11)

    fig.tight_layout()
    out_path = OUTPUT_DIR / '1_fluoroprobe_timeseries.png'
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    print(f"    Saved: {out_path.as_posix()}")
    plt.close(fig)

# Only the depth profile is drawn for the fluoroprobe section.
plot_fluoroprobe_depth_profile(fluoroprobe, station='A')
# Timeseries plot removed - wavelength signals should not be labeled as algal groups
# (plot_fluoroprobe_timeseries() stays defined above but is intentionally not called.)

# =============================================================================
# 2. LABORATORY CHLOROPHYLL MEASUREMENTS
# =============================================================================
# Section banner printed to stdout before the laboratory examples run.
print("\n" + "=" * 80)
print("2. LABORATORY CHLOROPHYLL MEASUREMENTS")
print("=" * 80)

def load_laboratory_chlorophyll():
    """Read the laboratory chlorophyll CSV and print summary statistics.

    Returns:
        pandas.DataFrame: the file contents with ``date`` parsed as datetimes.
    """
    print("\nLoading laboratory chlorophyll data...")
    csv_path = DATA_DIR.joinpath('laboratory_chlorophyll',
                                 'laboratory_chlorophyll_2017-2023_clean.csv')
    table = pd.read_csv(csv_path, parse_dates=['date'])
    print(f"Loaded {len(table):,} laboratory measurements")

    when = table['date']
    print(f"Date range: {when.min()} to {when.max()}")

    chl = table['chl']
    print(f"chl range: {chl.min():.2f} - {chl.max():.2f} μg/L")
    print(f"Mean: {chl.mean():.2f} μg/L")
    return table

# Load the laboratory chlorophyll table once; it is plotted further down.
lab_chl = load_laboratory_chlorophyll()

def plot_laboratory_chlorophyll_timeseries(df):
    """Draw the laboratory chlorophyll-a series with its overall mean.

    Args:
        df: Table with 'date' and 'chl' columns.

    Returns:
        None. Writes '2_laboratory_chlorophyll_timeseries.png' into OUTPUT_DIR.
    """
    print("\n  Plotting laboratory chlorophyll time series...")

    # Chronological order so the connecting line runs forward in time
    ordered = df.sort_values('date')
    axis_font = {'fontsize': 13, 'fontweight': 'bold'}
    marker_style = {
        'markersize': 8,
        'markerfacecolor': 'gold',
        'markeredgewidth': 2,
        'markeredgecolor': 'darkgreen',
    }

    fig, ax = plt.subplots(figsize=(16, 7))

    # Measurements as markers joined by a line
    ax.plot(ordered['date'], ordered['chl'], 'o-', linewidth=2,
            color='darkgreen', alpha=0.8,
            label='Laboratory Chl-a (Station A)', **marker_style)

    # Reference line at the mean of all measurements
    average = ordered['chl'].mean()
    ax.axhline(average, color='red', linestyle='--', linewidth=2,
               alpha=0.7, label=f'Mean = {average:.2f} ug/L')

    ax.set_title('Lake Kinneret Station A - Laboratory Chlorophyll-a Measurements (2017-2023)',
                 fontsize=15, fontweight='bold')
    ax.set_xlabel('Date', **axis_font)
    ax.set_ylabel('Chlorophyll-a (ug/L)', **axis_font)
    ax.tick_params(labelsize=11)
    ax.grid(True, alpha=0.4, linestyle='--')
    ax.legend(loc='best', fontsize=12, framealpha=0.95, edgecolor='black')
    ax.set_ylim(bottom=0)

    fig.tight_layout()
    fig.savefig(OUTPUT_DIR / '2_laboratory_chlorophyll_timeseries.png',
                dpi=300, bbox_inches='tight')
    print(f"    Saved: output/2_laboratory_chlorophyll_timeseries.png")
    plt.close(fig)

# Draw the laboratory chlorophyll figure from the table loaded above.
plot_laboratory_chlorophyll_timeseries(lab_chl)

# =============================================================================
# 3. SATELLITE CHLOROPHYLL RETRIEVALS
# =============================================================================
# Section banner printed to stdout before the satellite examples run.
print("\n" + "=" * 80)
print("3. SATELLITE CHLOROPHYLL RETRIEVALS (SENTINEL-2)")
print("=" * 80)

def load_satellite_chlorophyll(year=2023):
    """Load the satellite chlorophyll table for one year.

    Looks in DATA_DIR/satellite_chlorophyll for files named
    'S2_<year>_*_chl_clean.csv' and reads the first one in sorted order.

    Args:
        year: Year used to build the file name pattern.

    Returns:
        pandas.DataFrame with ``date`` parsed as datetimes, or None when no
        file matches the year.
    """
    print(f"\nLoading satellite chlorophyll data for {year}...")

    # Find satellite file(s) for the year
    sat_dir = DATA_DIR / 'satellite_chlorophyll'
    sat_files = sorted(sat_dir.glob(f'S2_{year}_*_chl_clean.csv'))

    if not sat_files:
        print(f"  Warning: No satellite files found for {year}")
        return None

    chosen = sat_files[0]
    if len(sat_files) > 1:
        # Make it visible that the remaining matches are not loaded.
        print(f"  Note: {len(sat_files)} files match {year}; "
              f"loading only the first one")

    df = pd.read_csv(chosen, parse_dates=['date'])
    print(f"Loaded {len(df):,} satellite observations")
    print(f"File: {chosen.name}")
    print(f"Date range: {df['date'].min()} to {df['date'].max()}")
    print(f"chl range: {df['chl'].min():.2f} - {df['chl'].max():.2f} μg/L")
    return df

# May be None when no satellite file exists for the requested year.
sat_chl = load_satellite_chlorophyll(year=2023)

def plot_satellite_chlorophyll_map(df, vmin=0, vmax=30):
    """Plot the spatial distribution of satellite chlorophyll for one scene.

    The scene shown is the middle entry of the sorted unique dates in ``df``.

    Args:
        df: Table with 'date', 'lon', 'lat' and 'chl' columns.
        vmin: Lower limit of the colour scale.
        vmax: Upper limit of the colour scale.

    Returns:
        None. Writes '3_satellite_chlorophyll_map.png' into OUTPUT_DIR, or
        prints a warning and returns early when ``df`` has no rows.
    """
    print("\n  Plotting satellite chlorophyll spatial map...")

    if len(df) == 0:
        # Picking the middle date of an empty list would raise IndexError.
        print("    Warning: No satellite observations to plot")
        return

    # Get a representative date (median)
    dates = sorted(df['date'].unique())
    date_filter = dates[len(dates) // 2]
    scene = df[df['date'] == date_filter]

    fig, ax = plt.subplots(figsize=(10, 12))

    scatter = ax.scatter(scene['lon'], scene['lat'], c=scene['chl'],
                         cmap='YlGnBu', s=10, alpha=0.8, vmin=vmin, vmax=vmax)
    ax.set_xlabel('Longitude (°E)', fontsize=12)
    ax.set_ylabel('Latitude (°N)', fontsize=12)
    ax.set_title(f'Lake Kinneret - Satellite chl-a (Sentinel-2)\n' +
                 f'{pd.to_datetime(date_filter).date()}',
                 fontsize=14, fontweight='bold')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)

    cbar = fig.colorbar(scatter, ax=ax, pad=0.02)
    cbar.set_label('chl-a (μg/L)', fontsize=12)

    fig.tight_layout()
    out_path = OUTPUT_DIR / '3_satellite_chlorophyll_map.png'
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    print(f"  Saved: {out_path.as_posix()}")
    plt.close(fig)

def plot_satellite_chlorophyll_histogram(df):
    """Draw a histogram of the 'chl' column with mean and median markers.

    Args:
        df: Table with a 'chl' column.

    Returns:
        None. Writes '3_satellite_chlorophyll_histogram.png' into OUTPUT_DIR.
    """
    print("\n  Plotting satellite chlorophyll distribution...")

    chl = df['chl']
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(chl, bins=50, color='seagreen', alpha=0.7, edgecolor='black')

    # Vertical dashed markers for the two summary statistics
    markers = (
        ('Mean', chl.mean(), 'red'),
        ('Median', chl.median(), 'orange'),
    )
    for stat_name, stat_value, line_color in markers:
        ax.axvline(stat_value, color=line_color, linestyle='--', linewidth=2,
                   label=f'{stat_name} = {stat_value:.2f} μg/L')

    ax.set_title('Lake Kinneret - Satellite chl-a Distribution',
                 fontsize=14, fontweight='bold')
    ax.set_xlabel('chl-a (μg/L)', fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    ax.grid(True, alpha=0.3, axis='y')
    ax.legend(loc='best', fontsize=11)

    fig.tight_layout()
    fig.savefig(OUTPUT_DIR / '3_satellite_chlorophyll_histogram.png',
                dpi=300, bbox_inches='tight')
    print(f"  Saved: output/3_satellite_chlorophyll_histogram.png")
    plt.close(fig)

# Skip the satellite figures when the loader found no file (sat_chl is None).
if sat_chl is not None:
    plot_satellite_chlorophyll_map(sat_chl)
    plot_satellite_chlorophyll_histogram(sat_chl)

# =============================================================================
# 4. METEOROLOGICAL OBSERVATIONS
# =============================================================================
# Section banner printed to stdout before the meteorology examples run.
print("\n" + "=" * 80)
print("4. METEOROLOGICAL OBSERVATIONS (GINOSAR STATION)")
print("=" * 80)

def load_meteorology():
    """Read the Ginosar meteorology CSV and print a short summary.

    Returns:
        pandas.DataFrame: the file contents with ``date`` parsed as datetimes.
    """
    print("\nLoading meteorological data...")
    csv_path = DATA_DIR.joinpath('meteorology', 'meteorology_ginosar_2017-2023.csv')
    table = pd.read_csv(csv_path, parse_dates=['date'])
    print(f"Loaded {len(table):,} observations")

    when = table['date']
    print(f"Date range: {when.min()} to {when.max()}")
    print("Temporal resolution: 10-minute intervals")
    return table

# Load the meteorology table once; it is plotted further down.
met = load_meteorology()

def plot_meteorology_overview(df):
    """Plot daily air temperature, wind speed, solar radiation and humidity.

    The 10-minute records are placed on a complete 10-minute grid and then
    reduced to one value per day. The ``*_avg`` and ``rh`` columns use the
    daily mean; the shaded "Daily Range" bands use the daily minimum of the
    ``*_min`` column and the daily maximum of the ``*_max`` column.

    Args:
        df: Table with a 'date' column plus 'airtc_avg', 'airtc_min',
            'airtc_max', 'ws_ms_avg', 'ws_ms_min', 'ws_ms_max', 'slrw_avg'
            and 'rh' columns.

    Returns:
        None. Writes '4_meteorology_overview.png' into OUTPUT_DIR.
    """
    print("\n  Plotting meteorology overview...")

    # Reindex to complete 10-min grid before resampling to daily values.
    # This prevents bias from missing nighttime records (e.g. Jan-Apr 2019
    # power gaps) inflating solar radiation and other daily means.
    full_idx = pd.date_range(df['date'].min(), df['date'].max(), freq='10min')
    df_full = df.set_index('date').reindex(full_idx)

    # Count the real solar records per day BEFORE gap-filling so that days
    # without a single observation can be blanked out again afterwards.
    solar_obs_per_day = df_full['slrw_avg'].resample('D').count()
    # Solar radiation: missing nighttime records should be 0, not NaN
    # NOTE(review): this also zero-fills missing daytime records, so partially
    # recorded days may still be biased low - confirm this is acceptable.
    df_full['slrw_avg'] = df_full['slrw_avg'].fillna(0)

    # Wind speed: flag the Jun 2020 stuck-sensor period as NaN
    wind_cols = ['ws_ms_avg', 'ws_ms_min', 'ws_ms_max']
    stuck_wind = (
        (df_full.index >= '2020-06-20')
        & (df_full.index <= '2020-06-30')
        & (df_full['ws_ms_avg'] <= 0.002)
    )
    df_full.loc[stuck_wind, wind_cols] = np.nan

    # A band labelled "Daily Range" must span the daily extremes, so the
    # min/max columns are reduced with min/max rather than with the mean.
    daily = df_full.resample('D').agg({
        'airtc_avg': 'mean', 'airtc_min': 'min', 'airtc_max': 'max',
        'ws_ms_avg': 'mean', 'ws_ms_min': 'min', 'ws_ms_max': 'max',
        'slrw_avg': 'mean',
        'rh': 'mean',
    })
    # A day with no solar record at all has no meaningful mean; show a gap
    # instead of a fabricated 0 W/m².
    daily.loc[solar_obs_per_day == 0, 'slrw_avg'] = np.nan
    # The reindexed index is unnamed, so reset_index() yields an 'index' column.
    daily = daily.reset_index().rename(columns={'index': 'date'})

    fig, axes = plt.subplots(4, 1, figsize=(14, 12), sharex=True)

    # Temperature
    axes[0].plot(daily['date'], daily['airtc_avg'], 'r-', linewidth=1.5)
    axes[0].fill_between(daily['date'], daily['airtc_min'], daily['airtc_max'],
                         alpha=0.3, color='red')
    axes[0].set_ylabel('Air Temperature (°C)', fontsize=11)
    axes[0].grid(True, alpha=0.3)
    axes[0].set_title('Lake Kinneret Ginosar Station - Meteorological Observations',
                      fontsize=14, fontweight='bold')
    axes[0].legend(['Daily Mean', 'Daily Range'], loc='best', fontsize=9)

    # Wind speed
    axes[1].plot(daily['date'], daily['ws_ms_avg'], 'b-', linewidth=1.5)
    axes[1].fill_between(daily['date'], daily['ws_ms_min'], daily['ws_ms_max'],
                         alpha=0.3, color='blue')
    axes[1].set_ylabel('Wind Speed (m/s)', fontsize=11)
    axes[1].grid(True, alpha=0.3)
    axes[1].legend(['Daily Mean', 'Daily Range'], loc='best', fontsize=9)

    # Solar radiation
    axes[2].plot(daily['date'], daily['slrw_avg'], 'orange', linewidth=1.5)
    axes[2].set_ylabel('Solar Radiation (W/m²)', fontsize=11)
    axes[2].grid(True, alpha=0.3)

    # Relative humidity
    axes[3].plot(daily['date'], daily['rh'], 'g-', linewidth=1.5)
    axes[3].set_ylabel('Relative Humidity (%)', fontsize=11)
    axes[3].set_xlabel('Date', fontsize=12)
    axes[3].grid(True, alpha=0.3)

    fig.tight_layout()
    out_path = OUTPUT_DIR / '4_meteorology_overview.png'
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    print(f"  Saved: {out_path.as_posix()}")
    plt.close(fig)

# Draw the four-panel meteorology figure from the table loaded above.
plot_meteorology_overview(met)

# =============================================================================
# 5. MACHINE LEARNING PHYTOPLANKTON PREDICTIONS
# =============================================================================
# Section banner printed to stdout before the ML prediction examples run.
print("\n" + "=" * 80)
print("5. MACHINE LEARNING PHYTOPLANKTON PREDICTIONS")
print("=" * 80)

def load_ml_predictions():
    """Read the ML phytoplankton prediction CSV and print a short summary.

    Returns:
        pandas.DataFrame: the file contents with ``date`` parsed as datetimes.
    """
    print("\nLoading ML predictions...")
    csv_path = DATA_DIR.joinpath('ml_predictions',
                                 'phytoplankton_predictions_2017_2023.csv')
    table = pd.read_csv(csv_path, parse_dates=['date'])
    print(f"Loaded {len(table):,} predictions")

    station_ids = sorted(table['station'].unique())
    print(f"Stations: {station_ids}")

    group_names = sorted(table['group_name'].unique())
    print(f"Phytoplankton groups: {group_names}")

    when = table['date']
    print(f"Date range: {when.min()} to {when.max()}")
    return table

# Load the ML prediction table once; both ML plots below reuse it.
ml_pred = load_ml_predictions()

def plot_ml_predictions_by_group(df, station='A'):
    """Draw one time-series panel per phytoplankton group for a station.

    Only rows shallower than 2 m are used; values are averaged per date and
    group before plotting.

    Args:
        df: Table with 'station', 'depth', 'date', 'group_name' and
            'predicted_biomass_ug_ml' columns.
        station: Station identifier to plot.

    Returns:
        None. Writes '5_ml_predictions_by_group.png' into OUTPUT_DIR.
    """
    print(f"\n  Plotting ML predictions for Station {station}...")

    # Surface rows (<2m) of the requested station
    at_station = df['station'] == station
    near_surface = df['depth'] < 2
    surface = df[at_station & near_surface].copy()

    # Mean biomass per (date, group)
    value_col = 'predicted_biomass_ug_ml'
    daily = surface.groupby(['date', 'group_name'])[value_col].mean().reset_index()

    # One panel per group, in this fixed order, each with its own colour
    palette = {
        'Cyanobacteria': 'blue',
        'Cryptophyta': 'purple',
        'Diatoms': 'brown',
        'Dinoflagellates': 'red',
        'Green Algae': 'green',
    }

    fig, axes = plt.subplots(5, 1, figsize=(14, 12), sharex=True)
    for ax, (group, color) in zip(axes, palette.items()):
        series = daily[daily['group_name'] == group]
        ax.plot(series['date'], series[value_col],
                color=color, linewidth=1.5, alpha=0.7)
        ax.set_ylabel(f'{group}\n(μg/mL)', fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.set_ylim(bottom=0)

    axes[-1].set_xlabel('Date', fontsize=12)
    headline = (f'Lake Kinneret Station {station} - ML Predicted Phytoplankton Biomass\n'
                'SVR model predictions from fluoroprobe measurements')
    fig.suptitle(headline, fontsize=14, fontweight='bold', y=0.995)

    fig.tight_layout()
    fig.savefig(OUTPUT_DIR / '5_ml_predictions_by_group.png',
                dpi=300, bbox_inches='tight')
    print(f"  Saved: output/5_ml_predictions_by_group.png")
    plt.close(fig)

def plot_ml_predictions_stacked(df, station='A'):
    """Plot a stacked area chart of predicted biomass per phytoplankton group.

    Only rows shallower than 2 m are used; values are averaged per date and
    group. A group with no predictions for the station is drawn as zero
    rather than raising KeyError.

    Args:
        df: Table with 'station', 'depth', 'date', 'group_name' and
            'predicted_biomass_ug_ml' columns.
        station: Station identifier to plot.

    Returns:
        None. Writes '5_ml_predictions_stacked.png' into OUTPUT_DIR, or prints
        a warning and returns early when the station has no surface rows.
    """
    print(f"\n  Plotting ML predictions stacked area chart...")

    # Filter for surface waters
    surface = df[(df['station'] == station) & (df['depth'] < 2)].copy()
    if surface.empty:
        print(f"    Warning: No surface predictions for Station {station}")
        return

    groups = ['Cyanobacteria', 'Cryptophyta', 'Diatoms',
              'Dinoflagellates', 'Green Algae']
    colors = ['blue', 'purple', 'brown', 'red', 'green']

    # Pivot to get groups as columns. reindex() guarantees every expected
    # group column exists (missing ones become NaN, then 0) and fixes the
    # stacking order.
    pivot = surface.pivot_table(
        index='date',
        columns='group_name',
        values='predicted_biomass_ug_ml',
        aggfunc='mean'
    ).reindex(columns=groups).fillna(0)

    fig, ax = plt.subplots(figsize=(14, 7))

    ax.stackplot(pivot.index,
                 *[pivot[group] for group in groups],
                 labels=groups,
                 colors=colors,
                 alpha=0.7)

    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel('Predicted Biomass (μg/mL)', fontsize=12)
    ax.set_title(f'Lake Kinneret Station {station} - Phytoplankton Community Composition\n' +
                 'ML predicted biomass from SVR model',
                 fontsize=14, fontweight='bold')
    ax.legend(loc='upper left', fontsize=10)
    ax.grid(True, alpha=0.3, axis='y')

    fig.tight_layout()
    out_path = OUTPUT_DIR / '5_ml_predictions_stacked.png'
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    print(f"  Saved: {out_path.as_posix()}")
    plt.close(fig)

# Draw both ML prediction figures for Station A.
plot_ml_predictions_by_group(ml_pred, station='A')
plot_ml_predictions_stacked(ml_pred, station='A')

# =============================================================================
# 6. VALIDATION DATA
# =============================================================================
# Section banner printed to stdout before the validation examples run.
print("\n" + "=" * 80)
print("6. VALIDATION DATA (MATCHUPS & STATISTICS)")
print("=" * 80)

def load_fluoroprobe_validation_matchups():
    """Read the satellite-fluoroprobe matchup CSV and print a short summary.

    Returns:
        pandas.DataFrame: matchup rows with ``fp_date`` and ``sat_date``
        parsed as datetimes.
    """
    print("\nLoading fluoroprobe validation matchups...")
    csv_path = DATA_DIR.joinpath('validation',
                                 'fluoroprobe_validation_matchups_clean.csv')
    pairs = pd.read_csv(csv_path, parse_dates=['fp_date', 'sat_date'])
    print(f"Loaded {len(pairs):,} matchup pairs")

    station_ids = sorted(pairs['station'].unique())
    print(f"Stations: {station_ids}")

    fp_dates = pairs['fp_date']
    print(f"Date range: {fp_dates.min()} to {fp_dates.max()}")
    return pairs

def load_laboratory_fluoroprobe_matchups():
    """Read the laboratory-fluoroprobe matchup CSV and print a short summary.

    Returns:
        pandas.DataFrame: matchup rows with ``lab_date`` and ``fp_date``
        parsed as datetimes.
    """
    print("\nLoading laboratory-fluoroprobe matchups...")
    csv_path = DATA_DIR.joinpath('validation', 'laboratory_fluoroprobe_matchups.csv')
    pairs = pd.read_csv(csv_path, parse_dates=['lab_date', 'fp_date'])
    print(f"Loaded {len(pairs):,} matchup pairs (Station A only)")

    lab_dates = pairs['lab_date']
    print(f"Date range: {lab_dates.min()} to {lab_dates.max()}")
    return pairs

def load_laboratory_satellite_matchups():
    """Read the laboratory-satellite matchup CSV and print a short summary.

    Returns:
        pandas.DataFrame: matchup rows with ``lab_date`` and ``sat_date``
        parsed as datetimes.
    """
    print("\nLoading laboratory-satellite matchups...")
    csv_path = DATA_DIR.joinpath('validation', 'laboratory_satellite_matchups.csv')
    pairs = pd.read_csv(csv_path, parse_dates=['lab_date', 'sat_date'])
    print(f"Loaded {len(pairs):,} matchup pairs (Station A only)")

    lab_dates = pairs['lab_date']
    print(f"Date range: {lab_dates.min()} to {lab_dates.max()}")
    return pairs

def load_validation_statistics():
    """Read the three validation statistics tables and print a line for each.

    Returns:
        tuple: ``(overall, station, seasonal)`` DataFrames, read from the
        overall, per-station and seasonal CSV files in DATA_DIR/validation.
    """
    print("\nLoading validation statistics...")
    stats_dir = DATA_DIR / 'validation'

    # Overall statistics: report the first row of the table
    overall = pd.read_csv(stats_dir / 'overall_validation_statistics.csv')
    n_total = overall['N'].iloc[0]
    r_squared = overall['R²'].iloc[0]
    rmse = overall['RMSE'].iloc[0]
    print(f"Overall: n={n_total}, R²={r_squared:.3f}, RMSE={rmse:.2f} μg/L")

    # Per-station statistics
    station = pd.read_csv(stats_dir / 'fluoroprobe_validation_statistics_clean.csv')
    print(f"Per-station stats: {len(station)} stations")

    # Seasonal statistics
    seasonal = pd.read_csv(stats_dir / 'seasonal_validation_statistics_clean.csv')
    print(f"Seasonal stats: {len(seasonal)} seasons")

    return overall, station, seasonal

# Load every validation table once; the three validation plots below reuse them.
fp_matchups = load_fluoroprobe_validation_matchups()
lab_fp_matchups = load_laboratory_fluoroprobe_matchups()
lab_sat_matchups = load_laboratory_satellite_matchups()
overall_stats, station_stats, seasonal_stats = load_validation_statistics()

def plot_satellite_fluoroprobe_validation(matchups, stats):
    """Plot satellite vs fluoroprobe chlorophyll scatter panels, one per station.

    The grid has three columns and as many rows as the stations found in
    ``matchups`` require (five stations give the 2x3 layout). Unused panels
    are hidden, stations without a predefined colour are drawn in gray, and
    the statistics box is omitted for stations missing from ``stats``.

    Args:
        matchups: Table with 'station', 'fp_chl' and 'sat_chl' columns.
        stats: Per-station table with 'Station', 'N', 'R²', 'RMSE' and
            'Bias' columns.

    Returns:
        None. Writes '6_satellite_fluoroprobe_validation.png' into OUTPUT_DIR,
        or prints a warning and returns early when there are no matchups.
    """
    print("\n  Plotting satellite-fluoroprobe validation...")

    stations = sorted(matchups['station'].unique())
    if not stations:
        print("    Warning: No matchups to plot")
        return

    n_cols = 3
    n_rows = -(-len(stations) // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 5 * n_rows),
                             squeeze=False)
    axes = axes.flatten()

    colors = {'A': 'red', 'D': 'orange', 'G': 'green', 'H': 'blue', 'K': 'purple'}

    # Plot each station
    for ax, station in zip(axes, stations):
        station_data = matchups[matchups['station'] == station]

        # Scatter plot
        ax.scatter(station_data['fp_chl'], station_data['sat_chl'],
                   alpha=0.6, s=30, color=colors.get(station, 'gray'),
                   edgecolors='black', linewidth=0.5)

        # 1:1 line
        max_val = max(station_data['fp_chl'].max(), station_data['sat_chl'].max())
        ax.plot([0, max_val], [0, max_val], 'k--', linewidth=1.5, label='1:1 line')

        # Statistics text (only when the station has a row in the stats table)
        stat_rows = stats[stats['Station'] == station]
        if not stat_rows.empty:
            station_stat = stat_rows.iloc[0]
            stats_text = (f"n = {int(station_stat['N'])}\n"
                          f"R² = {station_stat['R²']:.3f}\n"
                          f"RMSE = {station_stat['RMSE']:.2f} μg/L\n"
                          f"Bias = {station_stat['Bias']:.2f} μg/L")
            ax.text(0.05, 0.95, stats_text, transform=ax.transAxes,
                    fontsize=9, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        ax.set_xlabel('Fluoroprobe Chl-a (μg/L)', fontsize=10)
        ax.set_ylabel('Satellite Chl-a (μg/L)', fontsize=10)
        ax.set_title(f'Station {station}', fontweight='bold', fontsize=11)
        ax.grid(True, alpha=0.3)
        ax.set_xlim(0, max_val * 1.05)
        ax.set_ylim(0, max_val * 1.05)
        ax.set_aspect('equal')
        # The 1:1 line carries a label, so show it.
        ax.legend(loc='lower right', fontsize=9)

    # Hide every panel that did not receive a station
    for unused_ax in axes[len(stations):]:
        unused_ax.axis('off')

    fig.suptitle('Lake Kinneret - Satellite vs Fluoroprobe chl-a Validation\n' +
                 f'{len(stations)}-Station Analysis (2017-2023)',
                 fontsize=14, fontweight='bold')

    fig.tight_layout()
    out_path = OUTPUT_DIR / '6_satellite_fluoroprobe_validation.png'
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    print(f"  Saved: {out_path.as_posix()}")
    plt.close(fig)

def _draw_validation_panel(ax, reference, estimate, color, xlabel, ylabel, title):
    """Draw one reference-vs-estimate scatter panel with a 1:1 line and stats box."""
    ax.scatter(reference, estimate,
               alpha=0.6, s=40, color=color, edgecolors='black', linewidth=0.5)
    max_val = max(reference.max(), estimate.max())
    ax.plot([0, max_val], [0, max_val], 'k--', linewidth=2, label='1:1 line')

    # Use only complete pairs so n, R² and RMSE all describe the same rows.
    paired = reference.notna() & estimate.notna()
    ref, est = reference[paired], estimate[paired]
    n_pairs = len(ref)
    # Pearson r via NumPy: avoids a SciPy dependency for a single coefficient.
    r = np.corrcoef(ref, est)[0, 1] if n_pairs >= 2 else np.nan
    rmse = np.sqrt(np.mean((est - ref) ** 2)) if n_pairs else np.nan

    stats_text = f"n = {n_pairs}\nR² = {r**2:.3f}\nRMSE = {rmse:.2f} μg/L"
    ax.text(0.05, 0.95, stats_text, transform=ax.transAxes,
            fontsize=10, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    ax.set_xlabel(xlabel, fontsize=11)
    ax.set_ylabel(ylabel, fontsize=11)
    ax.set_title(title, fontweight='bold', fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal')
    # The 1:1 line carries a label, so show it.
    ax.legend(loc='lower right', fontsize=10)


def plot_laboratory_validation(lab_fp, lab_sat):
    """Plot laboratory chlorophyll against fluoroprobe and satellite values.

    Each panel shows the scatter, a 1:1 line and a box with the number of
    complete pairs, the squared Pearson correlation and the RMSE.

    Args:
        lab_fp: Table with 'lab_chl' and 'fp_chl' columns.
        lab_sat: Table with 'lab_chl' and 'sat_chl' columns.

    Returns:
        None. Writes '6_laboratory_validation.png' into OUTPUT_DIR.
    """
    print("\n  Plotting laboratory validation comparisons...")

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # Laboratory vs Fluoroprobe
    _draw_validation_panel(
        axes[0], lab_fp['lab_chl'], lab_fp['fp_chl'], 'steelblue',
        'Laboratory Chl-a (μg/L)', 'Fluoroprobe Chl-a (μg/L)',
        'Laboratory vs Fluoroprobe')

    # Laboratory vs Satellite
    _draw_validation_panel(
        axes[1], lab_sat['lab_chl'], lab_sat['sat_chl'], 'darkgreen',
        'Laboratory Chl-a (μg/L)', 'Satellite Chl-a (μg/L)',
        'Laboratory vs Satellite')

    fig.suptitle('Lake Kinneret Station A - Laboratory Validation\n' +
                 'Gold standard fluorometric chlorophyll-a',
                 fontsize=14, fontweight='bold')

    fig.tight_layout()
    out_path = OUTPUT_DIR / '6_laboratory_validation.png'
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    print(f"  Saved: {out_path.as_posix()}")
    plt.close(fig)

def plot_seasonal_validation(stats):
    """Plot per-season validation statistics as four bar charts.

    Panels show R², RMSE, bias and the number of matchups. Bars are annotated
    with their values: two decimals for the statistics, whole numbers for the
    matchup counts.

    Args:
        stats: Table with 'season', 'r2', 'rmse', 'bias' and 'n' columns.

    Returns:
        None. Writes '6_seasonal_validation.png' into OUTPUT_DIR.
    """
    print("\n  Plotting seasonal validation statistics...")

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    seasons = stats['season'].values
    x_pos = np.arange(len(seasons))

    # (axes, column, bar colour, y label, title, bar-label format)
    panels = [
        (axes[0, 0], 'r2', 'steelblue', 'R²',
         'Coefficient of Determination', '%.2f'),
        (axes[0, 1], 'rmse', 'coral', 'RMSE (μg/L)',
         'Root Mean Square Error', '%.2f'),
        (axes[1, 0], 'bias', 'lightgreen', 'Bias (μg/L)',
         'Mean Bias', '%.2f'),
        # Counts are integers; '%.0f' avoids labels such as "123.00".
        (axes[1, 1], 'n', 'purple', 'Number of matchups',
         'Sample Size', '%.0f'),
    ]

    for ax, column, color, ylabel, title, label_fmt in panels:
        bars = ax.bar(x_pos, stats[column], color=color, alpha=0.7,
                      edgecolor='black')
        ax.set_xticks(x_pos)
        ax.set_xticklabels(seasons)
        ax.set_ylabel(ylabel, fontsize=11)
        ax.set_title(title, fontweight='bold')
        ax.grid(True, alpha=0.3, axis='y')
        # Add values on bars
        ax.bar_label(bars, fmt=label_fmt, padding=3, fontsize=9)

    # Panel-specific extras
    axes[0, 0].set_ylim(0, 1)
    axes[1, 0].axhline(0, color='black', linestyle='--', linewidth=1)

    fig.suptitle('Lake Kinneret - Seasonal Validation Statistics\n' +
                 'Satellite vs Fluoroprobe chl-a (5 stations)',
                 fontsize=14, fontweight='bold')

    fig.tight_layout()
    out_path = OUTPUT_DIR / '6_seasonal_validation.png'
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    print(f"  Saved: {out_path.as_posix()}")
    plt.close(fig)

# Draw the three validation figures from the tables loaded above.
plot_satellite_fluoroprobe_validation(fp_matchups, station_stats)
plot_laboratory_validation(lab_fp_matchups, lab_sat_matchups)
plot_seasonal_validation(seasonal_stats)

# =============================================================================
# SUMMARY
# =============================================================================
# Lists only the figures this script actually writes.
print("\n" + "=" * 80)
print("EXAMPLE ANALYSIS COMPLETE!")
print("=" * 80)
print("\nGenerated plots (saved to output/ directory):")
print("  1. Fluoroprobe:")
print("     - output/1_fluoroprobe_depth_profile.png")
# The fluoroprobe time series is not listed: plot_fluoroprobe_timeseries()
# is never called by this script, so that file is not produced.
print("  2. Laboratory chlorophyll:")
print("     - output/2_laboratory_chlorophyll_timeseries.png")
print("  3. Satellite chlorophyll:")
if sat_chl is not None:
    print("     - output/3_satellite_chlorophyll_map.png")
    print("     - output/3_satellite_chlorophyll_histogram.png")
else:
    # The satellite figures are skipped when no satellite file was found.
    print("     - (skipped: no satellite data loaded)")
print("  4. Meteorology:")
print("     - output/4_meteorology_overview.png")
print("  5. ML predictions:")
print("     - output/5_ml_predictions_by_group.png")
print("     - output/5_ml_predictions_stacked.png")
print("  6. Validation:")
print("     - output/6_satellite_fluoroprobe_validation.png")
print("     - output/6_laboratory_validation.png")
print("     - output/6_seasonal_validation.png")
print("\n" + "=" * 80)
print("Dataset Information:")
print("  - Full documentation: README.md")
print("  - Technical metadata: METADATA.md")
print("  - Variable definitions: DATA_DICTIONARY.csv")
print("  - Validation details: validation/README_VALIDATION.md")
print("  - Citation: CITATION.cff")
print("  - License: CC BY 4.0 (LICENSE.txt)")
print("=" * 80)
